Skip to content

Commit d1c5a6e

Browse files
authored
[Embedding] Fix op dependency in init_from_checkpoint API. (#1012)
Signed-off-by: lightwang <[email protected]>
1 parent 9e30ab6 commit d1c5a6e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow/python/training/checkpoint_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ def _set_checkpoint_initializer(variable,
443443
is_partitioned_ev = variable._save_slice_info is not None
444444
partition_id = variable._save_slice_info.var_offset[0] if is_partitioned_ev else 0
445445
partition_num = variable._save_slice_info.full_shape[0] if is_partitioned_ev else 1
446-
with ops.control_dependencies([variable._initializer_op]):
446+
restore_dependency = ops.get_collection(ops.GraphKeys.EMBEDDING_VARIABLE_RESTORE_DEPENDENCY)[0]
447+
with ops.control_dependencies(restore_dependency[variable._primary_handle]):
447448
rank = variable.initial_value.get_shape().rank - 1
448449
restore_op = gen_kv_variable_ops.kv_resource_import_v3(
449450
ckpt_file,

0 commit comments

Comments
 (0)