@@ -2290,15 +2290,17 @@ def auto_train_steps(batch_size,
22902290
22912291@gin .configurable
22922292def get_checkpoint_iterator (checkpoint_step , model_dir , skip_until = 0 ,
2293- stop_after = None ):
2293+ stop_after = None , find_closest = True ):
22942294 """Get an iterable of checkpoint paths from a provided checkpoint step(s).
22952295
22962296 Args:
2297- checkpoint_step: If checkpoint_step is an int, find the checkpoint with the
2298- closest global step and return a singleton list. If checkpoint_step is a
2299- list of ints, replace each int with the path to the checkpoint with the
2300- closest global step. If checkpoint_step == "all", return the path of every
2301- checkpoint in model_dir, starting from the earliest checkpoint. If
2297+ checkpoint_step: If checkpoint_step is an int, return a singleton list with
2298+ that checkpoint path in it. If find_closest, the checkpoint with the
2299+ closest global step will be returned. If checkpoint_step is a
2300+ list of ints, replace each int with its corresponding path (if
2301+ find_closest, the path with the closest global step). If
2302+ checkpoint_step == "all", return the path of every checkpoint in
2303+ model_dir, starting from the earliest checkpoint. If
23022304 checkpoint_step == -1, return the latest checkpoint as specified in
23032305 model_dir/checkpoint. If checkpoint_step is None, return
23042306 `tf.train.checkpoints_iterator` for `model_dir`.
@@ -2308,6 +2310,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0,
23082310 stop_after: an optional integer - for "None" behavior, if specified,
23092311 stop after finding a checkpoint number that is >= stop_after. When a
23102312 checkpoint number == stop_after is found, it is yielded before exiting.
2313+ find_closest: If True and a specified checkpoint step does not exist,
2314+ choose the nearest checkpoint to that step. If False, only look
2315+ for a checkpoint matching the exact specified step.
23112316
23122317 Returns:
23132318 An iterable which yields checkpoint paths.
@@ -2338,6 +2343,10 @@ def _get_closest_checkpoint(target_checkpoint):
23382343 def _get_checkpoint_path (step ):
23392344 return os .path .join (model_dir , "model.ckpt-{}" .format (step ))
23402345
2346+ def _get_checkpoint_path_if_exists (step ):
2347+ path = _get_checkpoint_path (step )
2348+ return path if tf .train .checkpoint_exists (path ) else None
2349+
23412350 def _filter_fn (p ):
23422351 return get_step_from_checkpoint_path (p ) > skip_until
23432352
@@ -2363,11 +2372,22 @@ def _generate_checkpoints():
23632372 return _generate_checkpoints ()
23642373 else :
23652374 return checkpoints_iterator
2366- elif isinstance (checkpoint_step , int ):
2367- return [_get_checkpoint_path (_get_closest_checkpoint (checkpoint_step ))]
2375+ elif find_closest :
2376+ if isinstance (checkpoint_step , int ):
2377+ return [_get_checkpoint_path (_get_closest_checkpoint (checkpoint_step ))]
2378+ else :
2379+ closests = np .unique (
2380+ [_get_closest_checkpoint (c ) for c in checkpoint_step ])
2381+ return [_get_checkpoint_path (closest ) for closest in closests ]
23682382 else :
2369- closests = np .unique ([_get_closest_checkpoint (c ) for c in checkpoint_step ])
2370- return [_get_checkpoint_path (closest ) for closest in closests ]
2383+ if isinstance (checkpoint_step , int ):
2384+ checkpoint_step = [checkpoint_step ]
2385+ checkpoints = [_get_checkpoint_path_if_exists (c ) for c in checkpoint_step ]
2386+ checkpoints = [c for c in checkpoints if c ]
2387+ if not checkpoints :
2388+ raise ValueError ("You asked for checkpoints '%s' but none were found." %
2389+ str (checkpoint_step ))
2390+ return checkpoints
23712391
23722392
23732393# TODO(noam): provide a more informative string for layout_rules:
0 commit comments