Skip to content

Commit 9e30ab6

Browse files
authored
[Embedding] Check the sharded property of tf.train.Saver. (#996)
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 93c69ad commit 9e30ab6

File tree

22 files changed

+76
-71
lines changed

22 files changed

+76
-71
lines changed

modelzoo/bst/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,9 @@ def train(sess_config,
612612
hooks = []
613613
hooks.extend(input_hooks)
614614

615-
sharded_saver = tf_config != None
616615
scaffold = tf.train.Scaffold(
617616
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
618-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
617+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
619618

620619
stop_hook = tf.train.StopAtStepHook(last_step=steps)
621620
log_hook = tf.train.LoggingTensorHook(

modelzoo/dbmtl/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,9 @@ def train(sess_config,
527527
hooks = []
528528
hooks.extend(input_hooks)
529529

530-
sharded_saver = tf_config != None
531530
scaffold = tf.train.Scaffold(
532531
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
533-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
532+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
534533

535534
stop_hook = tf.train.StopAtStepHook(last_step=steps)
536535
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcn/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597-
sharded_saver = tf_config != None
598597
scaffold = tf.train.Scaffold(
599598
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
599+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601600

602601
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603602
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcnv2/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,9 @@ def train(sess_config,
610610
hooks = []
611611
hooks.extend(input_hooks)
612612

613-
sharded_saver = tf_config != None
614613
scaffold = tf.train.Scaffold(
615614
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
616-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
615+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
617616

618617
stop_hook = tf.train.StopAtStepHook(last_step=steps)
619618
log_hook = tf.train.LoggingTensorHook(

modelzoo/deepfm/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,9 @@ def train(sess_config,
472472
hooks = []
473473
hooks.extend(input_hooks)
474474

475-
sharded_saver = tf_config != None
476475
scaffold = tf.train.Scaffold(
477476
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
478-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
477+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
479478

480479
stop_hook = tf.train.StopAtStepHook(last_step=steps)
481480
log_hook = tf.train.LoggingTensorHook(

modelzoo/dien/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,10 +776,9 @@ def train(sess_config,
776776
hooks = []
777777
hooks.extend(input_hooks)
778778

779-
sharded_saver = tf_config != None
780779
scaffold = tf.train.Scaffold(
781780
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
782-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
781+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
783782

784783
stop_hook = tf.train.StopAtStepHook(last_step=steps)
785784
log_hook = tf.train.LoggingTensorHook(

modelzoo/din/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597-
sharded_saver = tf_config != None
598597
scaffold = tf.train.Scaffold(
599598
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
599+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601600

602601
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603602
log_hook = tf.train.LoggingTensorHook(

modelzoo/dlrm/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,9 @@ def train(sess_config,
507507
hooks = []
508508
hooks.extend(input_hooks)
509509

510-
sharded_saver = tf_config != None
511510
scaffold = tf.train.Scaffold(
512511
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
513-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
512+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
514513

515514
stop_hook = tf.train.StopAtStepHook(last_step=steps)
516515
log_hook = tf.train.LoggingTensorHook(

modelzoo/dssm/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,9 @@ def train(sess_config,
478478
hooks = []
479479
hooks.extend(input_hooks)
480480

481-
sharded_saver = tf_config != None
482481
scaffold = tf.train.Scaffold(
483482
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
484-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
483+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
485484

486485
stop_hook = tf.train.StopAtStepHook(last_step=steps)
487486
log_hook = tf.train.LoggingTensorHook(

modelzoo/esmm/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,10 +534,9 @@ def train(sess_config,
534534
hooks = []
535535
hooks.extend(input_hooks)
536536

537-
sharded_saver = tf_config != None
538537
scaffold = tf.train.Scaffold(
539538
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
540-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
539+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
541540

542541
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
543542
log_hook = tf.train.LoggingTensorHook(

modelzoo/masknet/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,9 @@ def train(sess_config,
529529
hooks = []
530530
hooks.extend(input_hooks)
531531

532-
sharded_saver = tf_config != None
533532
scaffold = tf.train.Scaffold(
534533
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
535-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
534+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
536535

537536
stop_hook = tf.train.StopAtStepHook(last_step=steps)
538537
log_hook = tf.train.LoggingTensorHook(

modelzoo/mlperf/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,9 @@ def train(sess_config,
522522
hooks = []
523523
hooks.extend(input_hooks)
524524

525-
sharded_saver = tf_config != None
526525
scaffold = tf.train.Scaffold(
527526
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
528-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
527+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
529528

530529
stop_hook = tf.train.StopAtStepHook(last_step=steps)
531530
log_hook = tf.train.LoggingTensorHook(

modelzoo/mmoe/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,9 @@ def train(sess_config,
523523
hooks = []
524524
hooks.extend(input_hooks)
525525

526-
sharded_saver = tf_config != None
527526
scaffold = tf.train.Scaffold(
528527
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
529-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
528+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
530529

531530
stop_hook = tf.train.StopAtStepHook(last_step=steps)
532531
log_hook = tf.train.LoggingTensorHook(

modelzoo/ple/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,9 @@ def train(sess_config,
592592
hooks = []
593593
hooks.extend(input_hooks)
594594

595-
sharded_saver = tf_config != None
596595
scaffold = tf.train.Scaffold(
597596
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
598-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
597+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
599598

600599
stop_hook = tf.train.StopAtStepHook(last_step=steps)
601600
log_hook = tf.train.LoggingTensorHook(

modelzoo/simple_multitask/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,9 @@ def train(sess_config,
427427
hooks = []
428428
hooks.extend(input_hooks)
429429

430-
sharded_saver = tf_config != None
431430
scaffold = tf.train.Scaffold(
432431
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
433-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
432+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
434433

435434
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
436435
log_hook = tf.train.LoggingTensorHook(

modelzoo/wide_and_deep/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,9 @@ def train(sess_config,
543543
hooks = []
544544
hooks.extend(input_hooks)
545545

546-
sharded_saver = tf_config != None
547546
scaffold = tf.train.Scaffold(
548547
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
549-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
548+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
550549

551550
stop_hook = tf.train.StopAtStepHook(last_step=steps)
552551
log_hook = tf.train.LoggingTensorHook(

tensorflow/python/feature_column/feature_column_v2_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
75277527
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
75287528
g_v = opt.compute_gradients(loss)
75297529
train_op = opt.apply_gradients(g_v)
7530-
saver = saver_module.Saver()
7530+
saver = saver_module.Saver(sharded=True)
75317531
init = variables_lib.global_variables_initializer()
75327532
with self.test_session() as sess:
75337533
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
77587758
g_v = opt.compute_gradients(loss)
77597759
train_op = opt.apply_gradients(g_v)
77607760
init = variables_lib.global_variables_initializer()
7761-
saver = saver_module.Saver()
7761+
saver = saver_module.Saver(sharded=True)
77627762

77637763
@test_util.run_deprecated_v1
77647764
def testEmbeddingVariableForInt32ID(self):
@@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self):
77837783
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
77847784
g_v = opt.compute_gradients(loss)
77857785
train_op = opt.apply_gradients(g_v)
7786-
saver = saver_module.Saver()
7786+
saver = saver_module.Saver(sharded=True)
77877787
init = variables_lib.global_variables_initializer()
77887788
with self.test_session() as sess:
77897789
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))

tensorflow/python/ops/embedding_variable_ops_gpu_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def testEmbeddingVariableForInitFromProto(self):
6363
g_v = opt.compute_gradients(loss)
6464
train_op = opt.apply_gradients(g_v)
6565
graph = ops.get_default_graph()
66-
meta_graph_def = saver_module.export_meta_graph()
66+
saver = saver_module.Saver(sharded=True)
67+
meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def())
6768
ops.reset_default_graph()
6869
with self.test_session() as sess:
6970
res = saver_module.import_meta_graph(meta_graph_def)
@@ -748,7 +749,7 @@ def testSaveV3(self):
748749
g_v = opt.compute_gradients(loss)
749750
train_op = opt.apply_gradients(g_v, global_step=gs)
750751
init = variables.global_variables_initializer()
751-
saver = saver = saver_module.Saver()
752+
saver = saver = saver_module.Saver(sharded=True)
752753
checkpoint_directory = self.get_temp_dir()
753754
model_path = os.path.join(checkpoint_directory, "model.ckpt")
754755
with self.test_session() as sess:
@@ -816,7 +817,7 @@ def testEmbeddingVariableSaveAndRestoreOptimzierStatesForMultiTierWithHbm(self):
816817
opt = adagrad.AdagradOptimizer(0.1)
817818
g_v = opt.compute_gradients(loss)
818819
train_op = opt.apply_gradients(g_v, gs)
819-
saver = saver_module.Saver()
820+
saver = saver_module.Saver(sharded=True)
820821
graph = ops.get_default_graph()
821822
with self.test_session(graph = graph) as sess:
822823
saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt-12345"))

0 commit comments

Comments
 (0)