This repository was archived by the owner on Jan 21, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 256
/
Copy pathutils.py
3079 lines (2710 loc) · 123 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2021 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Utilities for running training and inference.
The `run` function for training the Transformer model is defined in this file.
TODO(katherinelee): add details about gin.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import itertools
import math
import os
import random
import re
import time
import gin
import gin.tf
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import dataset as transformer_dataset
from mesh_tensorflow.transformer import learning_rate_schedules
from mesh_tensorflow.transformer import transformer
import numpy as np
import pkg_resources
import six
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from tensorflow.core.protobuf import rewriter_config_pb2 # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import resources # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.tpu import tpu_config # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.tpu import tpu_estimator # pylint: disable=g-direct-tensorflow-import
# Register the gin-related command-line flags. The except clause tolerates
# tf.flags raising DuplicateFlagError when one of these names is already
# defined in the process. Note: if an earlier DEFINE in the try body succeeds
# and a later one raises, the remaining DEFINEs in this block are skipped.
try:
  tf.flags.DEFINE_multi_string("gin_file", None, "Path to a Gin file.")
  tf.flags.DEFINE_multi_string("gin_param", None, "Gin parameter binding.")
  tf.flags.DEFINE_list("gin_location_prefix", [], "Gin file search path.")
except tf.flags.DuplicateFlagError:
  pass
FLAGS = tf.flags.FLAGS
# Package-relative path of the default gin config; resolved with
# pkg_resources.resource_filename(__name__, ...) in parse_gin_defaults_and_flags.
_DEFAULT_CONFIG_FILE = "./gin/defaults.gin"
# List of features used by model. _filter_features drops every example key
# that is not in this list.
_MODEL_FEATURES = [
    "inputs", "inputs_position", "inputs_segmentation", "targets",
    "targets_position", "targets_segmentation", "targets_subsegmentation"
]
def _filter_features(ex):
  """Returns a new dict with only the entries of `ex` keyed by a model feature.

  Args:
    ex: a mapping from feature name to value (anything with `.items()`).

  Returns:
    a dict containing the (name, value) pairs of `ex` whose name appears in
    _MODEL_FEATURES.
  """
  kept = {}
  for feature_name, feature_value in ex.items():
    if feature_name in _MODEL_FEATURES:
      kept[feature_name] = feature_value
  return kept
def parse_gin_defaults_and_flags(skip_unknown=False, finalize_config=True):
  """Parses the packaged default gin config, then the user-supplied ones.

  Directories listed in --gin_location_prefix are first registered with gin
  as config-file search paths. The bundled defaults file is parsed before the
  --gin_file / --gin_param values, so user files and bindings override the
  defaults.

  Args:
    skip_unknown: passed through to both gin parsing calls.
    finalize_config: passed through to gin.parse_config_files_and_bindings.
  """
  for search_dir in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(search_dir)

  # Defaults first; anything parsed afterwards takes precedence.
  defaults_path = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(defaults_path, skip_unknown=skip_unknown)

  gin.parse_config_files_and_bindings(
      FLAGS.gin_file,
      FLAGS.gin_param,
      skip_unknown=skip_unknown,
      finalize_config=finalize_config)
# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
# the VariableDType class and stop passing it all over creation.
@gin.configurable
def get_variable_dtype(
    master_dtype=tf.bfloat16,
    slice_dtype=tf.float32,
    activation_dtype=tf.float32):
  """Builds the mtf.VariableDType describing the datatypes for this run.

  Every argument is normalized with tf.as_dtype, so a tf.DType or anything
  else tf.as_dtype accepts (such as a dtype name string) may be given.

  Args:
    master_dtype: datatype for checkpoints; keep this the same between
      training and eval/inference.
    slice_dtype: datatype for variables in memory; must be tf.float32 for
      training.
    activation_dtype: datatype for activations; tf.bfloat16 uses less memory
      but may cause numerical issues.

  Returns:
    a mtf.VariableDType
  """
  master = tf.as_dtype(master_dtype)
  slices = tf.as_dtype(slice_dtype)
  activations = tf.as_dtype(activation_dtype)
  return mtf.VariableDType(master_dtype=master,
                           slice_dtype=slices,
                           activation_dtype=activations)
def inputs_vocabulary(vocabulary):
  """Selects the vocabulary that applies to the inputs feature.

  Args:
    vocabulary: a single Vocabulary shared by inputs and targets, or an
      (inputs_vocabulary, targets_vocabulary) tuple.

  Returns:
    element 0 when `vocabulary` is a tuple, otherwise `vocabulary` itself.
  """
  return vocabulary[0] if isinstance(vocabulary, tuple) else vocabulary
def targets_vocabulary(vocabulary):
  """Selects the vocabulary that applies to the targets feature.

  Args:
    vocabulary: a single Vocabulary shared by inputs and targets, or an
      (inputs_vocabulary, targets_vocabulary) tuple.

  Returns:
    element 1 when `vocabulary` is a tuple, otherwise `vocabulary` itself.
  """
  if not isinstance(vocabulary, tuple):
    return vocabulary
  return vocabulary[1]
@gin.configurable
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Gin-configurable helper that pairs an inputs and a targets vocabulary.

  The returned tuple has the (inputs_vocabulary, targets_vocabulary) layout
  that inputs_vocabulary() and targets_vocabulary() unpack.

  Args:
    inputs: the vocabulary for the inputs feature.
    targets: the vocabulary for the targets feature.

  Returns:
    the tuple (inputs, targets).
  """
  vocab_pair = (inputs, targets)
  return vocab_pair
@gin.configurable
def init_checkpoint_variable_mapping(name, mapping_fn=None):
  """Translates a graph variable name to the name to look up in a checkpoint.

  Args:
    name: the variable name in the current graph.
    mapping_fn: optional callable applied to `name`; when it is falsy the
      name is returned unchanged.

  Returns:
    the checkpoint variable name.
  """
  if not mapping_fn:
    return name
  return mapping_fn(name)
@gin.configurable
def should_load_variable(name, filter_fn=None):
  """Decides whether a global variable should be restored from a checkpoint.

  Args:
    name: the variable name in the current graph.
    filter_fn: optional predicate applied to `name`; when it is falsy every
      variable is loaded.

  Returns:
    the predicate's result, or True when no predicate is configured.
  """
  if not filter_fn:
    return True
  return filter_fn(name)
# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run
def build_model(model_type="bitransformer",
                input_vocab_size=gin.REQUIRED,
                output_vocab_size=gin.REQUIRED,
                layout_rules=None,
                mesh_shape=None):
  """Build a transformer model.

  Currently, five types of models are supported:

  "bitransformer": The traditional encoder-decoder architecture from
     "Attention is All You Need". Requires a non-text2self dataset.

  "lm": an autoregressive language model (one layer stack). Effectively the
     decoder of the bitransformer. There is no attention over the encoder, since
     there is no encoder. Requires a text2self dataset, with targets, but no
     inputs.

  "delimited_lm": an autoregressive language model trained on a text2text
     dataset. Each training example is expressed as
     [<input_tokens>, EOS, <target_tokens>, EOS]. Model checkpoints are
     compatible with "lm" models. One strategy is to pretrain as "lm"
     then fine-tune as "delimited_lm".

  "aligned": a non-autoregressive single-stack model (like BERT). Requires
     a non-text2self dataset with inputs and targets. The targets and inputs
     have the same length and each entry in the inputs is aligned to the
     corresponding entry in targets, eg:
      "inputs": "The X sat on X X."
      "targets": "The cat sat on the mat."
      (except, inputs are token ID sequences, not strings)

  "bi_student_teacher": a student-teacher model where both the student and
     teacher are bitransformers. Requires a non-text2self dataset.

  A text2self dataset has targets that are offset of the inputs. Non-text2self
  datasets have targets that differ from their inputs, like:
    input: 'hello'
    target: 'bonjour'

  Args:
    model_type: a string, one of "bitransformer", "lm", "delimited_lm",
      "aligned", or "bi_student_teacher"
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout_rules: optional, input to mtf.convert_to_layout_rules
    mesh_shape: optional, an input to mtf.convert_to_shape()

  Returns:
    a Unitransformer or Bitransformer

  Raises:
    ValueError: if model_type is not one of the supported values.
  """
  if model_type == "bitransformer":
    return transformer.make_bitransformer(
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        mesh_shape=mesh_shape,
        layout=layout_rules)
  elif model_type == "bi_student_teacher":
    return transformer.make_bi_student_teacher(
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        mesh_shape=mesh_shape,
        layout=layout_rules)
  elif model_type in ["lm", "delimited_lm", "aligned"]:
    return transformer.Unitransformer(
        # "aligned" is the only non-autoregressive single-stack model.
        autoregressive=model_type in ["lm", "delimited_lm"],
        layer_stack=transformer.make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        mesh_shape=mesh_shape,
        layout=layout_rules)
  else:
    raise ValueError("unknown model_type: %r" % (model_type,))
@gin.configurable
def tpu_mesh_shape(tpu_topology=gin.REQUIRED,
                   model_parallelism=gin.REQUIRED,
                   ensemble_parallelism=None):
  """Create a mesh_shape for data-parallelism and model-parallelism on TPU.

  Example: tpu_mesh_shape("4x4", 8) -> mtf.Shape(("batch", 4), ("model", 8))
  when FLAGS.logical_cores_per_chip == 2: there are 4x4x2=32 total cores,
  and we want 8-way model parallelism.

  This function is passed through gin to the argument `mesh_shape` inside the
  function `run`.

  Alternatively, for model_parallelism, pass a mesh_spec (see
  simd_mesh_impl.py).
  TODO(noam): describe

  Args:
    tpu_topology: a string - e.g. "2x2" or "v3-8". A string starting with "v"
      is read as "<version>-<num_cores>". Otherwise it is an "x"-separated
      list of chip dimensions (any "_twisted"/"_untwisted" suffix is dropped)
      whose product is multiplied by FLAGS.logical_cores_per_chip.
    model_parallelism: an integer - the number of cores per model replica
      alternatively a list that can be passed to
      simd_mesh_impl.HierarchicalTiling
    ensemble_parallelism: an optional integer - if present then create an
      "ensemble" mesh-dimension as well, for splitting the models in an
      ensemble.

  Returns:
    a mtf.Shape

  Raises:
    ValueError: if model_parallelism (times ensemble_parallelism, when given)
      is larger than the number of cores, leaving no core for data
      parallelism.
  """
  if tpu_topology.startswith("v"):
    num_cores = int(tpu_topology.split("-")[-1])
  else:
    # check for twisted topologies
    tpu_topology = re.split("_twisted|_untwisted", tpu_topology)[0]
    tpu_dim = [int(x) for x in tpu_topology.split("x")]
    num_cores = functools.reduce(lambda x, y: x * y,
                                 tpu_dim) * FLAGS.logical_cores_per_chip
  if isinstance(model_parallelism, list):
    # model_parallelism is actually a spec used to
    # construct a simd_mesh_impl.HierarchicalTiling object
    return mtf.simd_mesh_impl.HierarchicalTiling.spec_to_mesh_shape(
        model_parallelism, num_cores)
  data_parallelism = num_cores // model_parallelism
  if ensemble_parallelism:
    data_parallelism //= ensemble_parallelism
  if data_parallelism < 1:
    # Without this check the "batch" dimension would silently be dropped and
    # the returned mesh would need more cores than the topology provides.
    raise ValueError(
        "Not enough cores for model_parallelism=%r and "
        "ensemble_parallelism=%r: topology %r provides only %d cores."
        % (model_parallelism, ensemble_parallelism, tpu_topology, num_cores))
  dims = []
  if ensemble_parallelism and ensemble_parallelism > 1:
    dims.append(mtf.Dimension("ensemble", ensemble_parallelism))
  if data_parallelism > 1:
    dims.append(mtf.Dimension("batch", data_parallelism))
  if model_parallelism > 1:
    dims.append(mtf.Dimension("model", model_parallelism))
  return mtf.Shape(dims)
@gin.configurable
def variable_filter_max_size(v, max_size=1e7):
  """Predicate that accepts a variable only if it is not larger than max_size.

  Presumably intended as a `variable_filter` (mtf.Variable -> boolean); it
  only relies on `v` exposing a numeric `size` attribute.

  Args:
    v: an object with a `size` attribute.
    max_size: the largest accepted size (inclusive).

  Returns:
    True iff v.size <= max_size.
  """
  variable_size = v.size
  return variable_size <= max_size
def _build_ckpt_to_local_var_name_mapping(
ckpt_num_blocks, ckpt_num_layers, local_num_blocks,
local_num_layers, new_layers, regex_prefix=None):
"""Builds a mapping from checkpoint variable names to local variable names.
Args:
ckpt_num_blocks: an integer, number of blocks in checkpoint.
ckpt_num_layers: an integer, number of layers in checkpoint.
local_num_blocks: an integer, number of blocks in current model.
local_num_layers: an integer, number of layers in current model.
new_layers: a list of lists, specifying new layer indices in the current
model not present in the ckpt.
regex_prefix: optional, a string, specifying a prefix to match for
both checkpoint variables and ones in the current model.
Returns:
a dictionary where keys are checkpoint variable name regexes and
values are local variable name regexes. It specifies the mapping between
the checkpoint block/layer group and local block/layer group.
"""
def build_regex(layer_num, block_num, num_blocks):
base_regex = r"layer_{:0=3d}".format(layer_num)
if num_blocks is not None:
base_regex = r"block_{:0=3d}/".format(block_num) + base_regex
if regex_prefix is not None:
base_regex = regex_prefix + r".*" + base_regex
return base_regex
all_ckpt_name_regexes = []
for block_num in range(ckpt_num_blocks or 1):
for layer_num in range(ckpt_num_layers):
all_ckpt_name_regexes.append(
build_regex(layer_num, block_num, ckpt_num_blocks))
all_local_name_regexes = []
for block_num in range(local_num_blocks or 1):
for layer_num in range(local_num_layers):
# Skip the new layers in the mapping ordering.
if (new_layers is not None) and (layer_num in new_layers): continue
all_local_name_regexes.append(
build_regex(layer_num, block_num, local_num_blocks))
if len(all_ckpt_name_regexes) != len(all_local_name_regexes):
raise ValueError("Invalid checkpoint to load. Number of variables in ckpt "
"and current model (minus `new_layers`) must be equal.")
# Build a mapping from ckpt var regex to local var regex.
ckpt_var_name_to_local_var_name = {}
for ckpt_var_name, local_var_name in zip(
all_ckpt_name_regexes, all_local_name_regexes):
ckpt_var_name_to_local_var_name[ckpt_var_name] = local_var_name
return ckpt_var_name_to_local_var_name
def _match_ckpt_to_local_var_name(
ckpt_var_name, local_var_name, ckpt_var_name_to_local_var_name):
"""Returns True if this pair of vars should be loaded, False otherwise."""
# Name does not fall into the block/layer convention, so return identity.
# This will cover variable such as the global_step, embeddings, etc...
if "layer" not in ckpt_var_name or "layer" not in local_var_name:
return ckpt_var_name == local_var_name
# If the variable suffixes do not match they cannot be matched.
if ckpt_var_name.split("/")[-1] != local_var_name.split("/")[-1]:
return False
for ckpt_regex, var_regex in ckpt_var_name_to_local_var_name.items():
if (re.match(ckpt_regex, ckpt_var_name) and
re.match(var_regex, local_var_name)):
# Both ckpt and local var are the same layer/block group. Now check to
# see if its the same parameter in the layer/block group.
if ".*" in ckpt_regex:
ckpt_regex = ckpt_regex.split(".*")[1]
if ".*" in var_regex:
var_regex = var_regex.split(".*")[1]
if ckpt_var_name.replace(ckpt_regex, var_regex) == local_var_name:
return True
return False
def _compute_num_blocks_and_layer(var_names):
"""Takes list of variable names and outputs the max number of blocks/layers."""
encoder_decoder_model = any(
[re.match(r"^encoder/", var_name) for var_name in var_names])
def get_max_layer_or_block_num(regex, var_names):
matched_nums = [re.findall(regex, v) for v in var_names]
return max(
[int(num) + 1 for num in list(itertools.chain(*matched_nums))] + [-1])
if encoder_decoder_model:
enc_max_layer_num = get_max_layer_or_block_num(
r"encoder/.*layer_(\d{3})", var_names)
enc_max_block_num = get_max_layer_or_block_num(
r"encoder/.*block_(\d{3})", var_names)
dec_max_layer_num = get_max_layer_or_block_num(
r"decoder/.*layer_(\d{3})", var_names)
dec_max_block_num = get_max_layer_or_block_num(
r"decoder/.*block_(\d{3})", var_names)
max_layer_num = [enc_max_layer_num, dec_max_layer_num]
max_block_num = [enc_max_block_num, dec_max_block_num]
else:
max_layer_num = [get_max_layer_or_block_num(r"layer_(\d{3})", var_names)]
max_block_num = [get_max_layer_or_block_num(r"block_(\d{3})", var_names)]
max_block_num = [n if (n != -1) else None for n in max_block_num]
return max_block_num, max_layer_num
@gin.configurable
def flexible_ckpt_init_mapping(ckpt_path="", new_layers=None,
                               return_mapping_fn=True):
  """More flexibly handles loading a checkpoint.

  To be used as a mapping_fn in init_checkpoint_variable_mapping or as a
  filter_fn in should_load_variable, depending on return_mapping_fn.

  Covers three common cases when initializing from a checkpoint:
    (1) Loading a checkpoint that uses different block/layer numbering.
    (2) Inserting new layers into the current model that should not be loaded
        from the checkpoint (e.g. an extra DenseReLUDense layer in each block
        group of the current model).
    (3) Changing a layer's type relative to the checkpoint (e.g. replacing a
        DenseReLUDense layer with an MoE1D layer).

  Args:
    ckpt_path: string, saved checkpoint path to load.
    new_layers: optional list of lists specifying which indices in the layer
      stack are newly added in the current model; these are skipped when
      loading checkpoint weights. For an encoder-decoder model pass two lists,
      one for new encoder layers and one for new decoder layers
      (e.g. [[3], [1]]); for an LM pass a single list (e.g. [[3]]).
    return_mapping_fn: a boolean. If True, return a function mapping graph
      variable names to checkpoint variable names. If False, return a
      filter_fn reporting whether a graph variable should be loaded.

  Returns:
    mapping_fn(var_name) -> checkpoint variable name (raises KeyError for
    variables with no match), or filter_fn(var_name) -> bool.
  """
  tf.logging.info("Using flexible_ckpt_init_mapping.")
  ckpt_var_names = [name for name, _ in tf.train.list_variables(ckpt_path)]
  local_var_names = [var.op.name for var in tf.global_variables()]

  # Each result is a list: two entries (encoder, decoder) for encoder-decoder
  # models, one entry for LMs.
  ckpt_blocks, ckpt_layers = _compute_num_blocks_and_layer(ckpt_var_names)
  local_blocks, local_layers = _compute_num_blocks_and_layer(local_var_names)

  if len(ckpt_blocks) == 2:
    # Encoder-decoder model: build one prefixed mapping per stack.
    new_enc_layers, new_dec_layers = (
        (None, None) if new_layers is None else new_layers)
    enc_mapping = _build_ckpt_to_local_var_name_mapping(
        ckpt_blocks[0], ckpt_layers[0], local_blocks[0], local_layers[0],
        new_enc_layers, regex_prefix="encoder/")
    dec_mapping = _build_ckpt_to_local_var_name_mapping(
        ckpt_blocks[1], ckpt_layers[1], local_blocks[1], local_layers[1],
        new_dec_layers, regex_prefix="decoder/")
    mappings = [enc_mapping, dec_mapping]
  else:
    # LM model: a single, unprefixed mapping.
    lm_new_layers = None if new_layers is None else new_layers[0]
    mappings = [
        _build_ckpt_to_local_var_name_mapping(
            ckpt_blocks[0], ckpt_layers[0], local_blocks[0], local_layers[0],
            lm_new_layers)
    ]

  graph_var_to_ckpt_var = {}
  for local_name in local_var_names:
    candidates = [
        ckpt_name for ckpt_name in ckpt_var_names
        if any(_match_ckpt_to_local_var_name(ckpt_name, local_name, mapping)
               for mapping in mappings)
    ]
    if candidates:
      # When several checkpoint variables match, the last one (in checkpoint
      # listing order) is kept.
      graph_var_to_ckpt_var[local_name] = candidates[-1]

  def mapping_fn(var_name):
    return graph_var_to_ckpt_var[var_name]

  def filter_fn(var_name):
    return var_name in graph_var_to_ckpt_var

  if not return_mapping_fn:
    return filter_fn
  return mapping_fn
@gin.configurable(denylist=["predict_fn"]) # pass `predict_fn` through `run`
def tpu_estimator_model_fn(model_type,
transformer_model,
vocabulary,
model_dir,
use_tpu,
mesh_shape,
layout_rules,
batch_size,
sequence_length,
autostack,
keep_checkpoint_max,
save_checkpoints_steps,
learning_rate_schedule=None,
optimizer=None,
outer_batch_size=1,
tpu_summaries=False,
predict_fn=None,
score_in_predict_mode=False,
variable_filter=None,
init_checkpoint=None,
init_variable_filter="",
ensemble_inputs=None,
mesh_devices=None,
model_info_file=None,
hierarchical_tiling_spec=None):
"""Create a TPUEstimator model function.
Args:
model_type: a string. One of "bitransformer", "lm", "delimited_lm",
"aligned", or "bi_teacher_student"
transformer_model: a transformer.Unitransformer or transformer.Bitransformer
vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
targets_vocabulary) tuple. Used for decoding in predict mode.
model_dir: a string, directory to save the model to.
use_tpu: a boolean
mesh_shape: a mtf.Shape
layout_rules: a mtf.LayoutRules
batch_size: an integer
sequence_length: an integer or a dict from feature-key to integer
the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
autostack: a boolean
keep_checkpoint_max: an integer, maximum number of checkpoints to keep
save_checkpoints_steps: an integer, save a checkpoint every this number of
steps
learning_rate_schedule: a constant or a function from step to learning rate
optimizer: a class extending optimize.Optimizer, required for training
outer_batch_size: outer batch dimension that could be used to enable the mix
of data-parallel and model-parallel training of Mixture of Experts (MoE)
models
tpu_summaries: a boolean, use rewrites to make summaries work on TPU. This
may be slow, since it uses a host call hack.
predict_fn: an optional function, see docs for `run` for more information.
score_in_predict_mode: compute log-likelihood scores instead of predictions
variable_filter: controls which variables are trained.
If None (default), train all trainable variables.
If a string regex, train all variables that match this regex.
If a function (mtf.Variable -> boolean), then train variables for which
the function returns True.
init_checkpoint: a string, if not None then read in variables from this
checkpoint path when initializing variables. Will only initialize
variables that appear both in the current graph and the checkpoint.
init_variable_filter: a string, used only when init_checkpoint is set.
controls which variables are loaded from the checkpoint using regex.
if empty string (default), all variables from the checkpoint are loaded.
ensemble_inputs: an optional integer - pass the size of the ensemble to
train an ensemble where each model gets different inputs.
You also need to configure Unitransformer.ensemble to the right size.
If None, then all models are trained on the same inputs.
mesh_devices: a list of strings, the device names to use for each mesh
slice. Only required for GPU.
model_info_file: an optional string, information about variables and
operations will be logged to this file during the TRAIN mode.
hierarchical_tiling_spec: an optional list that can be passed as the
spec argument to simd_mesh_impl.HierarchicalTiling
Returns:
a function to be passed to TPUEstimator
"""
mesh_devices = mesh_devices or [""] * mesh_shape.size
def my_model_fn(features, labels, mode, params=None, config=None):
"""Estimator model function.
Args:
features: dictionary where keys are strings like "inputs" and "targets"
and the values are the actual values of "inputs". See TPUEstimator's
docs for more information
labels: ignored argument
mode: a tf.estimator.ModeKeys
params: dictionary containing the key "context"
config: ignored argument
Returns:
a TPUEstimatorSpec
"""
del labels, config
if mode == tf.estimator.ModeKeys.PREDICT and score_in_predict_mode:
mode = "score"
global_step = tf.train.get_global_step()
if use_tpu and "context" in params:
ctx = params["context"]
num_hosts = ctx.num_hosts
host_placement_fn = ctx.tpu_host_placement_function
device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
# TODO(ylc): Better estimation of replica cache size?
replica_cache_size = 300 * 1000000 # 300M per replica
# Worker 0 caches all the TPU binaries.
worker0_mem = replica_cache_size * ctx.num_replicas
devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
var_placer = mtf.utils.BalancedVariablePlacer(device_list,
devices_memeory_usage)
physical_shape = [int(i) for i in
params["context"].device_assignment.topology.mesh_shape]
mesh_4d = False
if len(physical_shape) == 4:
mesh_4d = True if physical_shape[2] > 1 else False
physical_shape = (
mtf.simd_mesh_impl.physical_shape_3d_from_topology_proto_4d(
physical_shape))
if mesh_4d or hierarchical_tiling_spec is None:
logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
mesh_shape.to_integer_list,
physical_shape,
device_assignment=params["context"].device_assignment)
else:
logical_to_physical = mtf.simd_mesh_impl.HierarchicalTiling(
hierarchical_tiling_spec,
physical_shape).logical_to_physical
mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
logical_to_physical=logical_to_physical)
else:
var_placer = None
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
mesh_shape, layout_rules, mesh_devices)
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh", var_placer)
if (outer_batch_size and
mode not in [tf.estimator.ModeKeys.PREDICT, "score"]):
outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
batch_dims = [outer_batch_dim, batch_dim]
else:
batch_dim = mtf.Dimension("batch", batch_size)
batch_dims = [batch_dim]
ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
if ensemble_inputs else [])
predict_batch_size = features.pop("predict_batch_size", None)
mtf_features = {}
for key, x in features.items():
# Some auxiliary features may have been generated in packing.
# The names of these new features are of the form
# "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
# We look up the lengths based on the original feature name, without
# the "_<suffix>".
feature_length = sequence_length[key.split("_")[0]]
length_dim = mtf.Dimension("length", feature_length)
feature_shape = mtf.Shape(
ensemble_dims + batch_dims + [length_dim])
x = tf.cast(features[key], tf.int32)
x = tf.reshape(x, feature_shape.to_integer_list)
if not use_tpu:
tf.logging.info("feature %s : %s" % (key, x))
mtf_features[key] = mtf.import_fully_replicated(
mesh, x, feature_shape, name=key)
def _verify_feature_exists(feature_name, should_exist):
if should_exist != (feature_name in mtf_features):
message = (
"mode=%s model_type=%s should%s have feature %s" %
(mode, model_type, "" if should_exist else " not", feature_name))
if "lm" in model_type:
message += (
"\nA common mistake is that model_type=\"delimited_lm\" should "
"be used with tasks that produce inputs and targets, while "
"model_type=\"lm\" should be used with tasks that produce "
"targets only.")
raise ValueError(message)
# Verify that the right features exist, and transform them if necessary
if mode == tf.estimator.ModeKeys.PREDICT:
_verify_feature_exists("inputs", True)
# "targets" may or may not exist depending on whether we are doing
# evaluation or open-ended inference.
elif model_type in ("lm", "delimited_lm") and mode == "score":
# in scoring mode the inputs and targets may already be combined.
if "inputs" in mtf_features:
if model_type == "lm":
tf.logging.warning(
"Scoring of lm models will include loss from the 'inputs'.")
mtf_features = _dynamic_text2self(mtf_features)
else:
_verify_feature_exists("targets", True)
_verify_feature_exists("inputs", model_type != "lm")
if model_type == "delimited_lm":
mtf_features = _dynamic_text2self(mtf_features)
# Detokenize in the graph if supported by vocabulary and accelerator.
def _maybe_detokenize(ids, vocab):
if not use_tpu and hasattr(vocab, "decode_tf"):
return vocab.decode_tf(ids)
return ids
if mode == "score":
# compute log-likelihoods per sequence
targets = mtf_features["targets"]
if predict_fn:
# predict_fn contains a custom scoring function
scores = predict_fn(
model=transformer_model,
features=mtf_features,
variable_dtype=get_variable_dtype())
else:
if isinstance(transformer_model, transformer.Unitransformer):
length_dim = targets.shape.dims[-1]
inputs = transformer.autoregressive_inputs(
mtf_features["targets"])
elif isinstance(transformer_model,
(transformer.Bitransformer,
transformer.StudentTeacher)):
inputs = mtf_features["inputs"]
else:
raise ValueError("unrecognized class")
logits, _ = transformer_model.call_simple(
inputs=inputs,
targets=targets,
compute_loss=False,
mode=mode,
variable_dtype=get_variable_dtype())
logits = mtf.cast(logits, tf.float32)
_, length_dim, vocab_dim = logits.shape.dims
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
logits, mtf_features["targets"], vocab_dim)
# 0=padding and negative targets are a hack to indicate no loss
cross_entropy *= mtf.cast(
mtf.greater(targets, 0), cross_entropy.dtype)
if model_type == "delimited_lm":
cross_entropy *= mtf.cast(mtf.logical_not(
transformer.delimited_lm_inputs_mask(targets)),
cross_entropy.dtype)
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
scores = mtf.anonymize(scores)
targets = mtf.anonymize(targets)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
targets = clean_decodes(lowering.export_to_tf_tensor(targets))
targets = _maybe_detokenize(targets, targets_vocabulary(vocabulary))
predictions = {
"targets": targets,
"scores": lowering.export_to_tf_tensor(scores)
}
elif mode == tf.estimator.ModeKeys.PREDICT:
inputs = mtf_features["inputs"]
if predict_fn:
mtf_samples = predict_fn(
model=transformer_model,
features=mtf_features,
variable_dtype=get_variable_dtype())
elif isinstance(transformer_model, transformer.Unitransformer):
# pad so that there is enough room for the targets
inputs = mtf.pad(
inputs, [0, sequence_length["targets"]], length_dim.name)
mtf_samples = transformer_model.sample_autoregressive(
inputs, variable_dtype=get_variable_dtype(),
remove_partial_sequences=True)
elif isinstance(
transformer_model,
(transformer.Bitransformer, transformer.StudentTeacher)):
mtf_samples = transformer_model.decode(
inputs, variable_dtype=get_variable_dtype())
else:
raise ValueError("unrecognized class")
mtf_samples = mtf.anonymize(mtf_samples)
inputs = mtf.anonymize(inputs)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))
if predict_batch_size is not None:
inputs = inputs[:predict_batch_size]
outputs = outputs[:predict_batch_size]
predictions = {
"inputs": inputs,
"outputs": outputs}
if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
# When exporting a model, we need to communicate to TF-Serving that
# master variables need to be copied to their slave slice variables.
# Estimator uses a Scaffold's "local_init_op" for this purpose, so we
# augment the default "local_init_op" here.
#
# The "ready_op" is also constructed here to ensure the variables
# initialized by "local_init_op" are the same ones checked by "ready_op".
#
# WARNING: Any variables created outside of this model_fn()
# (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
# checked by these ops.
def scaffold_fn():
  """Builds the Scaffold used for the "score" / PREDICT estimator spec.

  The local-init op extends Estimator's default local init with
  `lowering.copy_masters_to_slices()`. The ready op concatenates the reports
  of uninitialized variables and uninitialized resources, so that the same
  set of variables is both initialized and checked.

  Returns:
    A `tf.train.Scaffold` carrying the two ops described above.
  """
  # Default local initialization plus master -> slice variable copies,
  # grouped under one named op.
  init_parts = [
      tf.train.Scaffold.default_local_init_op(),
      lowering.copy_masters_to_slices(),
  ]
  local_init = tf.group(*init_parts, name="mtf_local_init_op")
  # Reports of uninitialized variables and resources, joined into a
  # single tensor along axis 0.
  uninitialized_reports = [
      tf.report_uninitialized_variables(),
      resources.report_uninitialized_resources(),
  ]
  ready = tf.concat(uninitialized_reports, axis=0, name="mtf_ready_op")
  return tf.train.Scaffold(local_init_op=local_init, ready_op=ready)
return tpu_estimator.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
scaffold_fn=scaffold_fn,
prediction_hooks=[mtf.MtfRestoreHook(lowering)])
assert (mode == tf.estimator.ModeKeys.TRAIN or
mode == tf.estimator.ModeKeys.EVAL)
def logits_and_loss(mtf_features, num_microbatches=1):
  """Runs `transformer_model.call_simple` with loss computation enabled.

  Args:
    mtf_features: a dict of mtf.Tensors keyed by feature name. "targets" is
      always read. "inputs" is read unless `model_type` is "lm" or
      "delimited_lm". Optional segmentation/position features are looked up
      with `.get` and default to None.
    num_microbatches: integer, forwarded unchanged to `call_simple`.

  Returns:
    Whatever `call_simple(..., compute_loss=True)` returns. Callers in this
    function take element [1] of the result as the loss (presumably a
    (logits, loss) pair).

  Raises:
    ValueError: if `transformer_model` is not a Unitransformer, not a
      Bitransformer, and `model_type` is not "bi_student_teacher".
  """
  # Language-model variants derive their inputs from the targets.
  if model_type in ["lm", "delimited_lm"]:
    model_inputs = transformer.autoregressive_inputs(
        mtf_features["targets"],
        sequence_id=mtf_features.get("targets_segmentation", None))
  else:
    model_inputs = mtf_features["inputs"]

  # (call_simple keyword, feature name) pairs for the positional and
  # segmentation information each model family accepts.
  if isinstance(transformer_model, transformer.Unitransformer):
    kwarg_to_feature = (
        ("sequence_id", "targets_segmentation"),
        ("position", "targets_position"),
    )
  elif (isinstance(transformer_model, transformer.Bitransformer)
        or model_type == "bi_student_teacher"):
    kwarg_to_feature = (
        ("encoder_sequence_id", "inputs_segmentation"),
        ("decoder_sequence_id", "targets_segmentation"),
        ("decoder_subsequence_id", "targets_subsegmentation"),
        ("encoder_position", "inputs_position"),
        ("decoder_position", "targets_position"),
    )
  else:
    raise ValueError("unrecognized class")
  position_kwargs = {
      kwarg: mtf_features.get(feature_name, None)
      for kwarg, feature_name in kwarg_to_feature
  }

  return transformer_model.call_simple(
      inputs=model_inputs,
      targets=mtf_features["targets"],
      compute_loss=True,
      mode=mode,
      variable_dtype=get_variable_dtype(),
      num_microbatches=num_microbatches,
      **position_kwargs)
if mode == tf.estimator.ModeKeys.TRAIN:
num_microbatches = serialize_num_microbatches(batch_dim,
sequence_length,
mesh_shape,
layout_rules)
if num_microbatches > 1:
def serialized_fn(mtf_features):
  """Computes the loss dict passed through `mtf.serialize_training_step`.

  Args:
    mtf_features: a dict of mtf.Tensors. Presumably one microbatch's
      features, as supplied by `mtf.serialize_training_step` (TODO confirm).

  Returns:
    A dict whose single "loss" entry is element [1] of `logits_and_loss`.
  """
  microbatch_loss = logits_and_loss(mtf_features, num_microbatches)[1]
  return {"loss": microbatch_loss}
var_grads, loss_dict = mtf.serialize_training_step(
mtf_features, serialized_fn, batch_dim, num_microbatches)
loss = loss_dict["loss"]
else:
loss = logits_and_loss(mtf_features)[1]
var_grads = mtf.gradients(
[loss], [v.outputs[0] for v in graph.trainable_variables])
if tpu_summaries:
mtf.scalar_summary("loss", loss)
if callable(learning_rate_schedule):
# the following happens on CPU since TPU can't handle summaries.
with mtf.utils.outside_all_rewrites():
learning_rate = learning_rate_schedule(
step=tf.train.get_global_step())
tf.summary.scalar("learning_rate", learning_rate)
else:
learning_rate = learning_rate_schedule
if isinstance(variable_filter, str):
pattern = re.compile(variable_filter)
variable_filter_fn = lambda v: pattern.search(v.name)
elif variable_filter is None:
variable_filter_fn = lambda v: True
elif callable(variable_filter):
variable_filter_fn = variable_filter
else:
raise ValueError(
"variable_filter must be None, a string, or a callable function")
trainable_vars = [
v for v in graph.trainable_variables if variable_filter_fn(v)]
trainable_var_grads = [
g for g, v in zip(var_grads, graph.trainable_variables)
if variable_filter_fn(v)]
if len(trainable_vars) != len(graph.trainable_variables):
tf.logging.info("Variables being trained:")
tf.logging.info([v.name for v in trainable_vars])
tf.logging.info("Variables not being trained:")
tf.logging.info([v.name for v in graph.trainable_variables
if not variable_filter_fn(v)])
update_ops = optimizer(learning_rate=learning_rate).apply_grads(
trainable_var_grads, trainable_vars
)
lowering = mtf.Lowering(
graph, {mesh: mesh_impl},
autostack=autostack,
log_file=model_info_file)
tf_loss = lowering.export_to_tf_tensor(loss)
tf_loss = tf.cast(tf_loss, tf.float32)
if not use_tpu:
tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
"step, tf_loss")
tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
tf_update_ops.append(tf.assign_add(global_step, 1))
train_op = tf.group(tf_update_ops)
if hasattr(transformer_model, "initialize"):
with mtf.utils.outside_all_rewrites():
transformer_model.initialize()
if tpu_summaries:
# has to be outside of
# with mtf.utils.outside_all_rewrites()
host_call = mtf.utils.create_host_call(model_dir)
mtf.utils.remove_summaries()
else:
host_call = None
with mtf.utils.outside_all_rewrites():
if init_checkpoint:
ckpt_vars = {v for v, _ in tf.train.list_variables(init_checkpoint)}
if init_variable_filter:
pattern = re.compile(init_variable_filter)
ckpt_vars = {v for v in ckpt_vars if pattern.search(v)}
global_vars = {v.op.name for v in tf.global_variables()}
filtered_global_vars = {
v for v in global_vars if should_load_variable(v)}
restore_vars = {
v for v in filtered_global_vars
if init_checkpoint_variable_mapping(v) in ckpt_vars}
tf.logging.info("Initializing variables from %s:", init_checkpoint)
tf.logging.debug("\n".join(sorted(restore_vars)))
tf.logging.info("Variables in %s but not in graph:", init_checkpoint)
tf.logging.info("\n".join(sorted(
ckpt_vars -
{init_checkpoint_variable_mapping(v)
for v in filtered_global_vars})))
tf.logging.info("Variables in graph but not in %s:", init_checkpoint)
tf.logging.info("\n".join(sorted(global_vars - restore_vars)))
tf.train.init_from_checkpoint(
init_checkpoint,
{init_checkpoint_variable_mapping(v): v for v in restore_vars}
)
# Copy master variables to slices. Must be called first.
restore_hook = mtf.MtfRestoreHook(lowering)
saver = tf.train.Saver(
tf.global_variables(),
sharded=True,
max_to_keep=keep_checkpoint_max,
keep_checkpoint_every_n_hours=2,
defer_build=False,
save_relative_paths=True)
tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
saver_listener = mtf.MtfCheckpointSaverListener(lowering)
saver_hook = tf.train.CheckpointSaverHook(
model_dir,
save_steps=save_checkpoints_steps,
saver=saver,
listeners=[saver_listener])
gin_config_saver_hook = gin.tf.GinConfigSaverHook(
model_dir, summarize_config=True, include_step_in_filename=False)